from autocoder.common import (
    AutoCoderArgs,
    TranslateArgs,
    TranslateReadme,
    split_code_into_segments,
    SourceCode,
)
from autocoder.common.buildin_tokenizer import BuildinTokenizer
from autocoder.pyproject import PyProject, Level1PyProject
from autocoder.tsproject import TSProject
from autocoder.suffixproject import SuffixProject
from autocoder.index.entry import build_index_and_filter_files
from autocoder.common.code_auto_merge import CodeAutoMerge
from autocoder.common.code_auto_merge_diff import CodeAutoMergeDiff
from autocoder.common.code_auto_merge_strict_diff import CodeAutoMergeStrictDiff
from autocoder.common.code_auto_merge_editblock import CodeAutoMergeEditBlock
from autocoder.common.code_auto_generate import CodeAutoGenerate
from autocoder.common.code_auto_generate_diff import CodeAutoGenerateDiff
from autocoder.common.code_auto_generate_strict_diff import CodeAutoGenerateStrictDiff
from autocoder.common.code_auto_generate_editblock import CodeAutoGenerateEditBlock
from typing import Optional, Generator
import byzerllm
import os
from autocoder.common.image_to_page import ImageToPage, ImageToPageDirectly
from autocoder.utils.conversation_store import store_code_model_conversation
from loguru import logger
import time
from autocoder.common.printer import Printer
from autocoder.utils.llms import get_llm_names
from autocoder.privacy.model_filter import ModelPathFilter
from autocoder.common import SourceCodeList
from autocoder.common.global_cancel import global_cancel
from autocoder.events.event_manager_singleton import get_event_manager
from autocoder.events import event_content as EventContentCreator
from autocoder.events.event_types import EventMetadata
from autocoder.common.v2.code_editblock_manager import CodeEditBlockManager


class BaseAction:
    """Common base for the project actions; offers a content-size helper."""

    def _get_content_length(self, content: str) -> int:
        """Measure ``content`` in tokens, or in characters if that fails.

        Args:
            content: Text whose size is wanted.

        Returns:
            The value of ``BuildinTokenizer().count_tokens(content)``. If
            building the tokenizer or counting raises any ``Exception``, a
            warning is logged and ``len(content)`` is returned instead.
        """
        try:
            token_total = BuildinTokenizer().count_tokens(content)
        except Exception as e:
            # Best-effort: sizing the content must never abort the action.
            logger.warning(
                f"Failed to use tokenizer to count tokens, fallback to len(): {e}")
            token_total = len(content)
        return token_total


class ActionTSProject(BaseAction):
    """Action for TypeScript projects: load sources, generate code, merge it.

    ``run`` only does work when ``args.project_type`` is ``"ts"``.
    """

    def __init__(
        self, args: AutoCoderArgs, llm: Optional[byzerllm.ByzerLLM] = None
    ) -> None:
        """Store the run configuration and the optional model client.

        Args:
            args: Run configuration read throughout this action.
            llm: Model client handed to the project loader, the generators and
                the mergers; may be ``None``.
        """
        self.args = args
        self.llm = llm
        # Populated by run() with the TSProject that was loaded.
        self.pp = None
        self.printer = Printer()

    def run(self) -> bool:
        """Load the TS project, optionally filter its sources, and process them.

        When ``args.image_file`` is set, an HTML page is produced from the
        image and appended to the sources with the tag ``"IMAGE"``.

        Returns:
            ``False`` without doing anything when ``args.project_type`` is not
            ``"ts"``; ``True`` after the sources have been processed.
        """
        args = self.args
        if args.project_type != "ts":
            return False
        pp = TSProject(args=args, llm=self.llm)
        self.pp = pp
        pp.run()

        source_code_list = SourceCodeList(pp.sources)
        if self.llm:
            old_query = args.query
            if args.in_code_apply:
                # The context is prepended for the filtering step only.
                args.query = (args.context or "") + "\n\n" + args.query
            try:
                source_code_list = build_index_and_filter_files(
                    llm=self.llm, args=args, sources=pp.sources
                )
            finally:
                # Restore the original query even when filtering raises.
                if args.in_code_apply:
                    args.query = old_query

        if args.image_file:
            if args.image_mode == "iterative":
                image_to_page = ImageToPage(llm=self.llm, args=args)
            else:
                image_to_page = ImageToPageDirectly(llm=self.llm, args=args)

            # The page is written to "<image dir>/html/<image name>.html".
            file_name = os.path.splitext(os.path.basename(args.image_file))[0]
            html_path = os.path.join(
                os.path.dirname(args.image_file), "html", f"{file_name}.html"
            )
            image_to_page.run_then_iterate(
                origin_image=args.image_file,
                html_path=html_path,
                max_iter=self.args.image_max_iter,
            )
            # NOTE(review): assumes run_then_iterate created html_path.
            with open(html_path, "r", encoding="utf-8") as f:
                html_code = f.read()

            source_code_list.sources.append(SourceCode(
                module_name=html_path,
                source_code=html_code,
                tag="IMAGE"))

        self.process_content(source_code_list)
        return True

    def process_content(self, source_code_list: SourceCodeList) -> None:
        """Generate code for ``args.query`` from the sources and merge it.

        Steps, all skipped when ``args.execute`` is falsy:

        1. Warn when the content exceeds ``args.model_max_input_length``.
        2. If an auto-fix option is enabled and ``args.auto_merge`` is
           ``"editblock"``, delegate everything to ``CodeEditBlockManager``.
        3. Otherwise run a single generation round, report token statistics
           to the terminal and the event manager, merge when
           ``args.auto_merge`` is truthy, and store the conversation.

        Args:
            source_code_list: Sources passed to the generator.
        """
        args = self.args
        content = source_code_list.to_str()
        if args.execute and self.llm and not args.human_as_model:
            content_length = self._get_content_length(content)
            if content_length > self.args.model_max_input_length:
                logger.warning(
                    f"Content(send to model) is {content_length} tokens, which is larger than the maximum input length {self.args.model_max_input_length}"
                )

        # Presumably raises when cancellation was requested for this event file.
        global_cancel.check_and_raise(token=self.args.event_file)

        auto_fix = args.enable_auto_fix_merge or args.enable_auto_fix_lint
        if auto_fix and args.execute and args.auto_merge == "editblock":
            code_merge_manager = CodeEditBlockManager(
                llm=self.llm, args=self.args, action=self)
            code_merge_manager.run(
                query=args.query, source_code_list=source_code_list)
            return

        if not args.execute:
            return

        self.printer.print_in_terminal("code_generation_start")
        start_time = time.time()
        # Any auto_merge value not listed here uses the plain generator.
        generator_classes = {
            "diff": CodeAutoGenerateDiff,
            "strict_diff": CodeAutoGenerateStrictDiff,
            "editblock": CodeAutoGenerateEditBlock,
        }
        generator_cls = generator_classes.get(args.auto_merge, CodeAutoGenerate)
        generate = generator_cls(llm=self.llm, args=self.args, action=self)

        generate_result = generate.single_round_run(
            query=args.query, source_code_list=source_code_list
        )
        elapsed_time = time.time() - start_time

        metadata = generate_result.metadata
        input_tokens = metadata.get('input_tokens_count', 0)
        output_tokens = metadata.get('generated_tokens_count', 0)
        input_tokens_cost = metadata.get('input_tokens_cost', 0)
        generated_tokens_cost = metadata.get('generated_tokens_cost', 0)
        # Tokens per second; guarded against a zero-length interval.
        speed = output_tokens / elapsed_time if elapsed_time > 0 else 0
        model_names = ",".join(get_llm_names(generate.llms))

        self.printer.print_in_terminal(
            "code_generation_complete",
            duration=elapsed_time,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_cost=input_tokens_cost,
            output_cost=generated_tokens_cost,
            speed=round(speed, 2),
            model_names=model_names,
            sampling_count=len(generate_result.contents)
        )

        get_event_manager(self.args.event_file).write_result(
            EventContentCreator.create_result(content=EventContentCreator.ResultTokenStatContent(
                model_name=model_names,
                elapsed_time=elapsed_time,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_cost=input_tokens_cost,
                output_cost=generated_tokens_cost,
                speed=round(speed, 2)
            )).to_dict(), metadata=EventMetadata(
                action_file=self.args.file
            ).to_dict())

        global_cancel.check_and_raise(token=self.args.event_file)

        merge_result = None
        if args.auto_merge:
            self.printer.print_in_terminal("code_merge_start")
            # Any truthy auto_merge value not listed here uses the plain merger.
            merger_classes = {
                "diff": CodeAutoMergeDiff,
                "strict_diff": CodeAutoMergeStrictDiff,
                "editblock": CodeAutoMergeEditBlock,
            }
            merger_cls = merger_classes.get(args.auto_merge, CodeAutoMerge)
            code_merge = merger_cls(llm=self.llm, args=self.args)
            merge_result = code_merge.merge_code(
                generate_result=generate_result)

        # Store the merged conversation when there is one; otherwise (merging
        # disabled or it produced nothing) store the raw generation.
        final_result = generate_result if merge_result is None else merge_result
        store_code_model_conversation(
            args=self.args,
            instruction=self.args.query,
            conversations=final_result.conversations[0],
            model=self.llm.default_model_name,
        )


class ActionPyProject(BaseAction):
    """Action for Python projects: load sources, generate code, merge it.

    ``run`` only does work when ``args.project_type`` is ``"py"``.
    """

    def __init__(
        self, args: AutoCoderArgs, llm: Optional[byzerllm.ByzerLLM] = None
    ) -> None:
        """Store the run configuration and the optional model client.

        Args:
            args: Run configuration read throughout this action.
            llm: Model client handed to the project loader, the generators and
                the mergers; may be ``None``.
        """
        self.args = args
        self.llm = llm
        # Populated by run() with the PyProject that was loaded.
        self.pp = None
        self.printer = Printer()

    def run(self) -> bool:
        """Load the Python project, optionally filter its sources, process them.

        ``args.py_packages`` is split on commas and passed to the project
        loader; an empty list is passed when it is unset or empty.

        Returns:
            ``False`` without doing anything when ``args.project_type`` is not
            ``"py"``; ``True`` after the sources have been processed.
        """
        args = self.args
        if args.project_type != "py":
            return False
        pp = PyProject(args=self.args, llm=self.llm)
        self.pp = pp
        packages = args.py_packages.split(",") if args.py_packages else []
        pp.run(packages=packages)
        source_code_list = SourceCodeList(pp.sources)

        if self.llm:
            old_query = args.query
            if args.in_code_apply:
                # The context is prepended for the filtering step only.
                args.query = (args.context or "") + "\n\n" + args.query
            try:
                source_code_list = build_index_and_filter_files(
                    llm=self.llm, args=args, sources=pp.sources
                )
            finally:
                # Restore the original query even when filtering raises.
                if args.in_code_apply:
                    args.query = old_query

        self.process_content(source_code_list)
        return True

    def process_content(self, source_code_list: SourceCodeList) -> None:
        """Generate code for ``args.query`` from the sources and merge it.

        Steps, all skipped when ``args.execute`` is falsy:

        1. Warn when the content exceeds ``args.model_max_input_length``.
        2. If an auto-fix option is enabled and ``args.auto_merge`` is
           ``"editblock"``, delegate everything to ``CodeEditBlockManager``.
        3. Otherwise run a single generation round, report token statistics
           to the terminal and the event manager, merge when
           ``args.auto_merge`` is truthy, and store the conversation.

        Args:
            source_code_list: Sources passed to the generator.
        """
        args = self.args
        content = source_code_list.to_str()
        if args.execute and self.llm and not args.human_as_model:
            content_length = self._get_content_length(content)
            if content_length > self.args.model_max_input_length:
                self.printer.print_in_terminal(
                    "code_execution_warning",
                    style="yellow",
                    content_length=content_length,
                    max_length=self.args.model_max_input_length
                )

        # Presumably raises when cancellation was requested for this event file.
        global_cancel.check_and_raise(token=self.args.event_file)

        auto_fix = args.enable_auto_fix_merge or args.enable_auto_fix_lint
        if auto_fix and args.execute and args.auto_merge == "editblock":
            code_merge_manager = CodeEditBlockManager(
                llm=self.llm, args=self.args, action=self)
            code_merge_manager.run(
                query=args.query, source_code_list=source_code_list)
            return

        if not args.execute:
            return

        self.printer.print_in_terminal("code_generation_start")
        start_time = time.time()
        # Any auto_merge value not listed here uses the plain generator.
        generator_classes = {
            "diff": CodeAutoGenerateDiff,
            "strict_diff": CodeAutoGenerateStrictDiff,
            "editblock": CodeAutoGenerateEditBlock,
        }
        generator_cls = generator_classes.get(args.auto_merge, CodeAutoGenerate)
        generate = generator_cls(llm=self.llm, args=self.args, action=self)

        generate_result = generate.single_round_run(
            query=args.query, source_code_list=source_code_list
        )
        elapsed_time = time.time() - start_time

        metadata = generate_result.metadata
        input_tokens = metadata.get('input_tokens_count', 0)
        output_tokens = metadata.get('generated_tokens_count', 0)
        input_tokens_cost = metadata.get('input_tokens_cost', 0)
        generated_tokens_cost = metadata.get('generated_tokens_cost', 0)
        # Tokens per second; guarded against a zero-length interval.
        speed = output_tokens / elapsed_time if elapsed_time > 0 else 0
        model_names = ",".join(get_llm_names(generate.llms))

        self.printer.print_in_terminal(
            "code_generation_complete",
            duration=elapsed_time,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_cost=input_tokens_cost,
            output_cost=generated_tokens_cost,
            speed=round(speed, 2),
            model_names=model_names,
            sampling_count=len(generate_result.contents)
        )

        get_event_manager(self.args.event_file).write_result(
            EventContentCreator.create_result(content=EventContentCreator.ResultTokenStatContent(
                model_name=model_names,
                elapsed_time=elapsed_time,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_cost=input_tokens_cost,
                output_cost=generated_tokens_cost,
                speed=round(speed, 2)
            )).to_dict(), metadata=EventMetadata(
                action_file=self.args.file
            ).to_dict())

        global_cancel.check_and_raise(token=self.args.event_file)

        merge_result = None
        if args.auto_merge:
            self.printer.print_in_terminal("code_merge_start")
            # Any truthy auto_merge value not listed here uses the plain merger.
            merger_classes = {
                "diff": CodeAutoMergeDiff,
                "strict_diff": CodeAutoMergeStrictDiff,
                "editblock": CodeAutoMergeEditBlock,
            }
            merger_cls = merger_classes.get(args.auto_merge, CodeAutoMerge)
            code_merge = merger_cls(llm=self.llm, args=self.args)
            merge_result = code_merge.merge_code(
                generate_result=generate_result)

        # Guard against a merger that returns None: fall back to the raw
        # generation instead of failing on attribute access.
        final_result = generate_result if merge_result is None else merge_result
        store_code_model_conversation(
            args=self.args,
            instruction=self.args.query,
            conversations=final_result.conversations[0],
            model=self.llm.default_model_name,
        )


class ActionSuffixProject(BaseAction):
    """Action for suffix-based projects: load sources, generate code, merge it.

    ``run`` does not check ``args.project_type``; it always loads a
    ``SuffixProject``.
    """

    def __init__(
        self, args: AutoCoderArgs, llm: Optional[byzerllm.ByzerLLM] = None
    ) -> None:
        """Store the run configuration and the optional model client.

        Args:
            args: Run configuration read throughout this action.
            llm: Model client handed to the project loader, the generators and
                the mergers; may be ``None``.
        """
        self.args = args
        self.llm = llm
        # Populated by run() with the SuffixProject that was loaded.
        self.pp = None
        self.printer = Printer()

    def run(self) -> None:
        """Load the project, optionally filter its sources, and process them.

        Returns:
            ``None``; this method has no explicit return value.
        """
        args = self.args
        pp = SuffixProject(args=args, llm=self.llm)
        self.pp = pp
        pp.run()
        source_code_list = SourceCodeList(pp.sources)
        if self.llm:
            old_query = args.query
            if args.in_code_apply:
                # The context is prepended for the filtering step only.
                args.query = (args.context or "") + "\n\n" + args.query
            try:
                source_code_list = build_index_and_filter_files(
                    llm=self.llm, args=args, sources=pp.sources
                )
            finally:
                # Restore the original query even when filtering raises.
                if args.in_code_apply:
                    args.query = old_query
        self.process_content(source_code_list)

    def process_content(self, source_code_list: SourceCodeList) -> None:
        """Generate code for ``args.query`` from the sources and merge it.

        Steps, all skipped when ``args.execute`` is falsy:

        1. Warn when the content exceeds ``args.model_max_input_length``.
        2. If an auto-fix option is enabled and ``args.auto_merge`` is
           ``"editblock"``, delegate everything to ``CodeEditBlockManager``.
        3. Otherwise run a single generation round, report token statistics
           to the terminal and the event manager, merge when
           ``args.auto_merge`` is truthy, and store the conversation.

        Args:
            source_code_list: Sources passed to the generator.
        """
        args = self.args
        content = source_code_list.to_str()

        if args.execute and self.llm and not args.human_as_model:
            content_length = self._get_content_length(content)
            if content_length > self.args.model_max_input_length:
                logger.warning(
                    f"Content(send to model) is {content_length} tokens, which is larger than the maximum input length {self.args.model_max_input_length}"
                )

        # Presumably raises when cancellation was requested for this event file.
        global_cancel.check_and_raise(token=self.args.event_file)

        auto_fix = args.enable_auto_fix_merge or args.enable_auto_fix_lint
        if auto_fix and args.execute and args.auto_merge == "editblock":
            code_merge_manager = CodeEditBlockManager(
                llm=self.llm, args=self.args, action=self)
            code_merge_manager.run(
                query=args.query, source_code_list=source_code_list)
            return

        # Nothing was generated, so there is nothing to report, merge or
        # store. Without this guard the statistics code below would read
        # locals that were never assigned.
        if not args.execute:
            return

        self.printer.print_in_terminal("code_generation_start")
        start_time = time.time()
        # Any auto_merge value not listed here uses the plain generator.
        generator_classes = {
            "diff": CodeAutoGenerateDiff,
            "strict_diff": CodeAutoGenerateStrictDiff,
            "editblock": CodeAutoGenerateEditBlock,
        }
        generator_cls = generator_classes.get(args.auto_merge, CodeAutoGenerate)
        generate = generator_cls(llm=self.llm, args=self.args, action=self)

        generate_result = generate.single_round_run(
            query=args.query, source_code_list=source_code_list
        )
        elapsed_time = time.time() - start_time

        metadata = generate_result.metadata
        input_tokens = metadata.get('input_tokens_count', 0)
        output_tokens = metadata.get('generated_tokens_count', 0)
        input_tokens_cost = metadata.get('input_tokens_cost', 0)
        generated_tokens_cost = metadata.get('generated_tokens_cost', 0)
        # Tokens per second; guarded against a zero-length interval.
        speed = output_tokens / elapsed_time if elapsed_time > 0 else 0
        model_names = ",".join(get_llm_names(generate.llms))

        self.printer.print_in_terminal(
            "code_generation_complete",
            duration=elapsed_time,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_cost=input_tokens_cost,
            output_cost=generated_tokens_cost,
            speed=round(speed, 2),
            model_names=model_names,
            sampling_count=len(generate_result.contents)
        )

        get_event_manager(self.args.event_file).write_result(
            EventContentCreator.create_result(content=EventContentCreator.ResultTokenStatContent(
                model_name=model_names,
                elapsed_time=elapsed_time,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_cost=input_tokens_cost,
                output_cost=generated_tokens_cost,
                speed=round(speed, 2)
            )).to_dict(), metadata=EventMetadata(
                action_file=self.args.file
            ).to_dict())

        global_cancel.check_and_raise(token=self.args.event_file)

        merge_result = None
        if args.auto_merge:
            self.printer.print_in_terminal("code_merge_start")
            # Any truthy auto_merge value not listed here uses the plain merger.
            merger_classes = {
                "diff": CodeAutoMergeDiff,
                "strict_diff": CodeAutoMergeStrictDiff,
                "editblock": CodeAutoMergeEditBlock,
            }
            merger_cls = merger_classes.get(args.auto_merge, CodeAutoMerge)
            code_merge = merger_cls(llm=self.llm, args=self.args)
            merge_result = code_merge.merge_code(
                generate_result=generate_result)

        # Store the merged conversation when there is one; otherwise (merging
        # disabled or it produced nothing) store the raw generation.
        final_result = generate_result if merge_result is None else merge_result
        store_code_model_conversation(
            args=self.args,
            instruction=self.args.query,
            conversations=final_result.conversations[0],
            model=self.llm.default_model_name,
        )
